{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# 🖌️ **Finetuning Taiyi-Stable-Diffusion Colab Example**\n",
        "\n",
        "#####based on https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-EN-v0.1\n"
      ],
      "metadata": {
        "id": "-GisYq7cG41a"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Installing fengshen framework"
      ],
      "metadata": {
        "id": "twrdGg5zaY0m"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from IPython.display import clear_output\n",
        "\n",
        "!pip install pytorch_lightning\n",
        "!pip install transformers\n",
        "!pip install deepspeed\n",
        "!pip install diffusers\n",
        "!pip install datasets\n",
        "!pip install accelerate\n",
        "\n",
        "!git clone https://github.com/IDEA-CCNL/Fengshenbang-LM\n",
        "\n",
        "clear_output()\n",
        "print(\"Done!\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Y24PHP7dG4gj",
        "outputId": "8c444a57-dfc8-4e6e-84f6-f7cbdde03c68"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Done!\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "# 切换工作路径\n",
        "os.chdir('/content/Fengshenbang-LM')\n",
        "print(os.getcwd())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lwZ2CAgkLgda",
        "outputId": "d2471d59-c1a5-43d1-fb19-055bf8dfde2c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/Fengshenbang-LM\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Building modules"
      ],
      "metadata": {
        "id": "EMYaGij5acpb"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CnXybs4VFJnz"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import torch\n",
        "import argparse\n",
        "from pytorch_lightning import (\n",
        "    LightningModule,\n",
        "    Trainer,\n",
        ")\n",
        "from pytorch_lightning.callbacks import (\n",
        "    LearningRateMonitor,\n",
        ")\n",
        "from fengshen.data.universal_datamodule import UniversalDataModule\n",
        "from fengshen.models.model_utils import (\n",
        "    add_module_args,\n",
        "    configure_optimizers,\n",
        "    get_total_steps,\n",
        ")\n",
        "from fengshen.utils.universal_checkpoint import UniversalCheckpoint\n",
        "from transformers import BertTokenizer, BertModel\n",
        "from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel\n",
        "from torch.nn import functional as F\n",
        "from fengshen.data.taiyi_stable_diffusion_datasets.taiyi_datasets import add_data_args, load_data\n",
        "from torchvision import transforms\n",
        "from PIL import Image\n",
        "from torch.utils.data._utils.collate import default_collate\n",
        "\n",
        "\n",
        "class Collator():\n",
        "    def __init__(self, args, tokenizer):\n",
        "        self.image_transforms = transforms.Compose(\n",
        "            [\n",
        "                transforms.Resize(\n",
        "                    args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n",
        "                transforms.CenterCrop(\n",
        "                    args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n",
        "                transforms.ToTensor(),\n",
        "                transforms.Normalize([0.5], [0.5]),\n",
        "            ]\n",
        "        )\n",
        "        self.tokenizer = tokenizer\n",
        "\n",
        "    def __call__(self, inputs):\n",
        "        examples = []\n",
        "        max_length = min(max([len(i['caption']) for i in inputs]), 512)\n",
        "        for i in inputs:\n",
        "            example = {}\n",
        "            instance_image = Image.open(i['img_path'])\n",
        "            if not instance_image.mode == \"RGB\":\n",
        "                instance_image = instance_image.convert(\"RGB\")\n",
        "            example[\"pixel_values\"] = self.image_transforms(instance_image)\n",
        "            example[\"input_ids\"] = self.tokenizer(\n",
        "                i['caption'],\n",
        "                padding=\"max_length\",\n",
        "                truncation=True,\n",
        "                max_length=max_length,\n",
        "                return_tensors='pt',\n",
        "            )['input_ids'][0]\n",
        "            examples.append(example)\n",
        "        return default_collate(examples)\n",
        "\n",
        "class StableDiffusion(LightningModule):\n",
        "    @staticmethod\n",
        "    def add_module_specific_args(parent_parser):\n",
        "        parser = parent_parser.add_argument_group('Taiyi Stable Diffusion Module')\n",
        "        parser.add_argument('--freeze_unet', action='store_true', default=False)\n",
        "        parser.add_argument('--freeze_text_encoder', action='store_true', default=False)\n",
        "        return parent_parser\n",
        "\n",
        "    def __init__(self, args):\n",
        "        super().__init__()\n",
        "        self.tokenizer = BertTokenizer.from_pretrained(\n",
        "            args.model_path, subfolder=\"tokenizer\")\n",
        "        self.text_encoder = BertModel.from_pretrained(\n",
        "            args.model_path, subfolder=\"text_encoder\")  # load from taiyi_finetune-v0\n",
        "        self.vae = AutoencoderKL.from_pretrained(\n",
        "            args.model_path, subfolder=\"vae\")\n",
        "        self.unet = UNet2DConditionModel.from_pretrained(\n",
        "            args.model_path, subfolder=\"unet\")\n",
        "        # TODO: 使用xformers配合deepspeed速度反而有下降(待确认\n",
        "        self.unet.set_use_memory_efficient_attention_xformers(False)\n",
        "\n",
        "        self.noise_scheduler = DDPMScheduler(\n",
        "            beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000\n",
        "        )\n",
        "\n",
        "        for param in self.vae.parameters():\n",
        "            param.requires_grad = False\n",
        "\n",
        "        if args.freeze_text_encoder:\n",
        "            for param in self.text_encoder.parameters():\n",
        "                param.requires_grad = False\n",
        "\n",
        "        if args.freeze_unet:\n",
        "            for param in self.unet.parameters():\n",
        "                param.requires_grad = False\n",
        "\n",
        "        self.save_hyperparameters(args)\n",
        "\n",
        "    def setup(self, stage) -> None:\n",
        "        if stage == 'fit':\n",
        "            self.total_steps = get_total_steps(self.trainer, self.hparams)\n",
        "            print('Total steps: {}' .format(self.total_steps))\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        return configure_optimizers(self)\n",
        "\n",
        "    def training_step(self, batch, batch_idx):\n",
        "        self.text_encoder.train()\n",
        "\n",
        "        latents = self.vae.encode(batch[\"pixel_values\"]).latent_dist.sample()\n",
        "        latents = latents * 0.18215\n",
        "\n",
        "        # Sample noise that we'll add to the latents\n",
        "        noise = torch.randn(latents.shape).to(latents.device)\n",
        "        noise = noise.to(dtype=self.unet.dtype)\n",
        "        bsz = latents.shape[0]\n",
        "        # Sample a random timestep for each image\n",
        "        timesteps = torch.randint(\n",
        "            0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n",
        "        timesteps = timesteps.long()\n",
        "        # Add noise to the latents according to the noise magnitude at each timestep\n",
        "        # (this is the forward diffusion process)\n",
        "\n",
        "        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)\n",
        "        noisy_latents = noisy_latents.to(dtype=self.unet.dtype)\n",
        "\n",
        "        # Get the text embedding for conditioning\n",
        "        encoder_hidden_states = self.text_encoder(batch[\"input_ids\"])[0]\n",
        "\n",
        "        # Predict the noise residual\n",
        "        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample\n",
        "\n",
        "        loss = F.mse_loss(noise_pred, noise, reduction=\"none\").mean([1, 2, 3]).mean()\n",
        "        self.log(\"train_loss\", loss.item(),  on_epoch=False, prog_bar=True, logger=True)\n",
        "\n",
        "        if self.trainer.global_rank == 0 and self.global_step == 100:\n",
        "            # 打印显存占用\n",
        "            from fengshen.utils.utils import report_memory\n",
        "            report_memory('stable diffusion')\n",
        "\n",
        "        return {\"loss\": loss}\n",
        "\n",
        "    def on_save_checkpoint(self, checkpoint) -> None:\n",
        "        if self.trainer.global_rank == 0:\n",
        "            print('saving model...')\n",
        "            pipeline = StableDiffusionPipeline.from_pretrained(\n",
        "                self.hparams.model_path,\n",
        "                text_encoder=self.text_encoder,\n",
        "                tokenizer=self.tokenizer,\n",
        "                unet=self.unet)\n",
        "            self.trainer.current_epoch\n",
        "            pipeline.save_pretrained(os.path.join(\n",
        "                args.default_root_dir, f'hf_out_{self.trainer.current_epoch}_{self.trainer.global_step}'))\n",
        "\n",
        "    def on_load_checkpoint(self, checkpoint) -> None:\n",
        "        # 兼容低版本lightning，低版本lightning从ckpt起来时steps数会被重置为0\n",
        "        global_step_offset = checkpoint[\"global_step\"]\n",
        "        if 'global_samples' in checkpoint:\n",
        "            self.consumed_samples = checkpoint['global_samples']\n",
        "        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Settings"
      ],
      "metadata": {
        "id": "jN-ATKxi1TUa"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from pprint import pprint\n",
        "\n",
        "args_parser = argparse.ArgumentParser()\n",
        "args_parser = add_module_args(args_parser)\n",
        "args_parser = add_data_args(args_parser)\n",
        "args_parser = UniversalDataModule.add_data_specific_args(args_parser)\n",
        "args_parser = Trainer.add_argparse_args(args_parser)\n",
        "args_parser = StableDiffusion.add_module_specific_args(args_parser)\n",
        "args_parser = UniversalCheckpoint.add_argparse_args(args_parser)\n",
        "\n",
        "# 你的数据集，可以参考 https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/main/fengshen/examples/finetune_taiyi_stable_diffusion 的demo_dataset的设置\n",
        "your_dataset_path = '/content/Fengshenbang-LM/fengshen/examples/finetune_taiyi_stable_diffusion/demo_dataset' #@param {type:\"string\"}\n",
        "# 默认为下载huggingface上的模型\n",
        "your_model_path =  'IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1' #@param {type:\"string\"}\n",
        "train_batch_size = '1' #@param {type:\"string\"}\n",
        "\n",
        "message = [\n",
        "    '--datasets_path', your_dataset_path,\n",
        "    '--datasets_type', 'txt',\n",
        "    '--model_path', your_model_path,\n",
        "    '--train_batchsize', train_batch_size,\n",
        "    '--accelerator', 'gpu',\n",
        "    # '--strategy', 'deepspeed',\n",
        "    '--precision', '16',\n",
        "]\n",
        "\n",
        "args = args_parser.parse_args(args=message)\n",
        "\n",
        "pprint(vars(args), width = 230)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "KFnQFiQ_1S8w",
        "outputId": "0802adc1-2b62-4557-96aa-796b5a2ca535"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'accelerator': 'gpu',\n",
            " 'accumulate_grad_batches': None,\n",
            " 'adam_beta1': 0.9,\n",
            " 'adam_beta2': 0.999,\n",
            " 'adam_epsilon': 1e-08,\n",
            " 'amp_backend': None,\n",
            " 'amp_level': None,\n",
            " 'auto_lr_find': False,\n",
            " 'auto_scale_batch_size': False,\n",
            " 'auto_select_gpus': None,\n",
            " 'benchmark': None,\n",
            " 'center_crop': False,\n",
            " 'check_val_every_n_epoch': 1,\n",
            " 'dataloader_workers': 2,\n",
            " 'datasets_name': None,\n",
            " 'datasets_path': ['/content/Fengshenbang-LM/fengshen/examples/finetune_taiyi_stable_diffusion/demo_dataset'],\n",
            " 'datasets_type': ['txt'],\n",
            " 'default_root_dir': None,\n",
            " 'detect_anomaly': False,\n",
            " 'devices': None,\n",
            " 'enable_checkpointing': True,\n",
            " 'enable_model_summary': True,\n",
            " 'enable_progress_bar': True,\n",
            " 'every_n_epochs': None,\n",
            " 'every_n_train_steps': None,\n",
            " 'fast_dev_run': False,\n",
            " 'filename': 'model-ep{epoch:02d}-st{step:d}',\n",
            " 'freeze_text_encoder': False,\n",
            " 'freeze_unet': False,\n",
            " 'gpus': None,\n",
            " 'gradient_clip_algorithm': None,\n",
            " 'gradient_clip_val': None,\n",
            " 'inference_mode': True,\n",
            " 'ipus': None,\n",
            " 'learning_rate': 5e-05,\n",
            " 'limit_predict_batches': None,\n",
            " 'limit_test_batches': None,\n",
            " 'limit_train_batches': None,\n",
            " 'limit_val_batches': None,\n",
            " 'load_ckpt_path': './ckpt/',\n",
            " 'log_every_n_steps': 50,\n",
            " 'logger': True,\n",
            " 'lr_decay_ratio': 1.0,\n",
            " 'lr_decay_steps': 0,\n",
            " 'max_epochs': None,\n",
            " 'max_steps': -1,\n",
            " 'max_time': None,\n",
            " 'min_epochs': None,\n",
            " 'min_learning_rate': 1e-07,\n",
            " 'min_steps': None,\n",
            " 'mode': 'max',\n",
            " 'model_path': 'IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1',\n",
            " 'monitor': 'step',\n",
            " 'move_metrics_to_cpu': False,\n",
            " 'multiple_trainloader_mode': 'max_size_cycle',\n",
            " 'num_nodes': 1,\n",
            " 'num_processes': None,\n",
            " 'num_sanity_val_steps': 2,\n",
            " 'num_workers': 8,\n",
            " 'overfit_batches': 0.0,\n",
            " 'plugins': None,\n",
            " 'precision': 16,\n",
            " 'profiler': None,\n",
            " 'raw_file_type': 'json',\n",
            " 'reload_dataloaders_every_n_epochs': 0,\n",
            " 'replace_sampler_ddp': True,\n",
            " 'resolution': 512,\n",
            " 'resume_from_checkpoint': None,\n",
            " 'sampler_type': 'random',\n",
            " 'save_ckpt_path': './ckpt/',\n",
            " 'save_last': False,\n",
            " 'save_on_train_epoch_end': None,\n",
            " 'save_top_k': 10,\n",
            " 'save_weights_only': False,\n",
            " 'scheduler_type': 'polynomial',\n",
            " 'strategy': None,\n",
            " 'sync_batchnorm': False,\n",
            " 'test_batchsize': 16,\n",
            " 'test_datasets_field': 'test',\n",
            " 'test_file': None,\n",
            " 'thres': 0.2,\n",
            " 'tpu_cores': None,\n",
            " 'track_grad_norm': -1,\n",
            " 'train_batchsize': 1,\n",
            " 'train_datasets_field': 'train',\n",
            " 'train_file': None,\n",
            " 'val_batchsize': 16,\n",
            " 'val_check_interval': None,\n",
            " 'val_datasets_field': 'validation',\n",
            " 'val_file': None,\n",
            " 'warmup_ratio': 0.1,\n",
            " 'warmup_steps': 0,\n",
            " 'weight_decay': 0.1}\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Start training"
      ],
      "metadata": {
        "id": "sgSAEhHoagek"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!nvidia-smi\n",
        "!cat /proc/meminfo"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yALlfBnj4AUF",
        "outputId": "528bd7a5-1c9a-48e5-d92a-471ac9774dde"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mon Feb 13 05:26:48 2023       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   71C    P0    32W /  70W |      3MiB / 15360MiB |      0%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "|  No running processes found                                                 |\n",
            "+-----------------------------------------------------------------------------+\n",
            "MemTotal:       26690612 kB\n",
            "MemFree:        22642924 kB\n",
            "MemAvailable:   24857828 kB\n",
            "Buffers:           48944 kB\n",
            "Cached:          2435656 kB\n",
            "SwapCached:            0 kB\n",
            "Active:           331644 kB\n",
            "Inactive:        3390024 kB\n",
            "Active(anon):       1568 kB\n",
            "Inactive(anon):  1233664 kB\n",
            "Active(file):     330076 kB\n",
            "Inactive(file):  2156360 kB\n",
            "Unevictable:           0 kB\n",
            "Mlocked:               0 kB\n",
            "SwapTotal:             0 kB\n",
            "SwapFree:              0 kB\n",
            "Dirty:               396 kB\n",
            "Writeback:             0 kB\n",
            "AnonPages:       1237216 kB\n",
            "Mapped:           537236 kB\n",
            "Shmem:              1304 kB\n",
            "KReclaimable:     103288 kB\n",
            "Slab:             147424 kB\n",
            "SReclaimable:     103288 kB\n",
            "SUnreclaim:        44136 kB\n",
            "KernelStack:        5216 kB\n",
            "PageTables:        21332 kB\n",
            "NFS_Unstable:          0 kB\n",
            "Bounce:                0 kB\n",
            "WritebackTmp:          0 kB\n",
            "CommitLimit:    13345304 kB\n",
            "Committed_AS:    3402512 kB\n",
            "VmallocTotal:   34359738367 kB\n",
            "VmallocUsed:       57876 kB\n",
            "VmallocChunk:          0 kB\n",
            "Percpu:             2672 kB\n",
            "HardwareCorrupted:     0 kB\n",
            "AnonHugePages:         0 kB\n",
            "ShmemHugePages:        0 kB\n",
            "ShmemPmdMapped:        0 kB\n",
            "FileHugePages:         0 kB\n",
            "FilePmdMapped:         0 kB\n",
            "CmaTotal:              0 kB\n",
            "CmaFree:               0 kB\n",
            "HugePages_Total:       0\n",
            "HugePages_Free:        0\n",
            "HugePages_Rsvd:        0\n",
            "HugePages_Surp:        0\n",
            "Hugepagesize:       2048 kB\n",
            "Hugetlb:               0 kB\n",
            "DirectMap4k:      484160 kB\n",
            "DirectMap2M:    15241216 kB\n",
            "DirectMap1G:    13631488 kB\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import pytorch_lightning as pl\n",
        "print(pl.__version__)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "s1vid40pLVDF",
        "outputId": "25a92804-0fbd-4004-83fd-80e763286555"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "1.9.1\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "lr_monitor = LearningRateMonitor(logging_interval='step')\n",
        "checkpoint_callback = UniversalCheckpoint(args)\n",
        "\n",
        "trainer = Trainer.from_argparse_args(args,\n",
        "                                         callbacks=[\n",
        "                                             lr_monitor,\n",
        "                                             checkpoint_callback])\n",
        "\n",
        "model = StableDiffusion(args)\n",
        "tokenizer = model.tokenizer\n",
        "\n",
        "datasets = load_data(args, global_rank=trainer.global_rank)\n",
        "collate_fn = Collator(args, tokenizer)\n",
        "\n",
        "datamoule = UniversalDataModule(\n",
        "    tokenizer=tokenizer, collate_fn=collate_fn, args=args, datasets=datasets)\n",
        "\n",
        "trainer.fit(model, datamoule)"
      ],
      "metadata": {
        "id": "b4nSmmNrLVwG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "It might be OOM, which is caused by low GPU memory in Colab.\n",
        "\n",
        "This notebook proves that our codes can run in our settings."
      ],
      "metadata": {
        "id": "9DnOM7qNbokd"
      }
    }
  ]
}
